from kblam.utils.model_utils import load_model_and_tokenizer
from kblam.utils.eval_utils import answer_question
from kblam.utils.train_utils import get_kb_embd
from kblam.utils.data_utils import load_entities
import argparse
import numpy as np
import torch

def parse_args():
    """Parse the command-line options of this script from ``sys.argv``.

    Every option is a string and none is marked as required, so an option
    that is not supplied is ``None`` on the returned namespace (except
    ``--encoder_model``, which falls back to ``"all-MiniLM-L6-v2"``).

    Returns:
        argparse.Namespace: attributes ``base_model``, ``hf_token``,
        ``encoder_model``, ``query_head_ckpt``, ``adapter_ckpt``, ``Q``,
        ``key_embeds_path``, ``value_embeds_path`` and ``dataset_path``.
    """
    # (flag, extra keyword arguments) in the order they appear in --help.
    option_specs = (
        ("--base_model", {"help": "Model name"}),
        ("--hf_token", {"help": "Huggingface token. Required for Llama models"}),
        (
            "--encoder_model",
            {
                "help": "Sentence encoder model name",
                "default": "all-MiniLM-L6-v2",
                "choices": ["all-MiniLM-L6-v2", "text-embedding-3-large", "ada-embeddings"],
            },
        ),
        ("--query_head_ckpt", {"help": "Path to the trained query head"}),
        ("--adapter_ckpt", {"help": "Path to the adapter (KB encoder) checkpoint"}),
        ("--Q", {"help": "Question to ask"}),
        ("--key_embeds_path", {"help": "Path to the key embedding numpy file"}),
        ("--value_embeds_path", {"help": "Path to the value embedding numpy file"}),
        ("--dataset_path", {"help": "Path to the dataset json file"}),
    )

    cli = argparse.ArgumentParser()
    for flag, extra in option_specs:
        cli.add_argument(flag, type=str, **extra)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: load the model, build a knowledge base from
    # precomputed key/value embeddings plus a few hand-written facts, and
    # print the model's answer to the question passed via --Q.
    args = parse_args()
    model, tokenizer, encoder = load_model_and_tokenizer(
        args.base_model,
        args.hf_token,
        args.encoder_model,
        3,  # NOTE(review): presumably the KB layer frequency (same value as kb_config below) — confirm
        args.query_head_ckpt,
        args.adapter_ckpt,
    )

    kb_config = {
                    'sep_query_head': True,
                    'kb_scale_factor': None,
                    'eval_mode': False,
                    # True for every third index (0, 3, 6, ...).
                    'append_kb_condition': lambda x: (x % 3 == 0),
                    'kb_layer_frequency': 3
                }

    dataset = load_entities(args.dataset_path)
    # The options are declared as --value_embeds_path / --key_embeds_path, so
    # argparse stores them under exactly those attribute names.
    value_embds = np.load(args.value_embeds_path).astype('float32')
    key_embds = np.load(args.key_embeds_path).astype('float32')

    # Fixed seed so the same 200 indices (drawn with replacement from
    # [0, len(dataset))) are picked on every run.
    # NOTE(review): assumes the embedding files have one row per dataset
    # entry, in the same order — confirm against how they were generated.
    np.random.seed(1)
    idx = np.random.randint(0, len(dataset), (200,))
    kb_embedding = get_kb_embd(encoder, idx, precomputed_embd=(key_embds, value_embds))

    # Extra (key, value) facts encoded on the fly and added to the sampled KB.
    new_kv_pairs = [
        ('the purpose of Project Alexandria', "to   extract information from videos"),
        ('the purpose of MSFT JFK', "to build more efficient trading platform"),
        ('the description of MSFT JFK', "a trading firm"),
        ('the purpose of Astromech Droids', "to tell you the odds"),
        ('the purpose of KBALM', "to make language models more reliable"),
        ('the most famous dish of Chef Lou', 'egg fried rice'),
        ('the national drink of Scotland', 'hot water'),
        ('the purpose of wine', 'to help you sleep better'),
    ]

    kb_embedding_extended = kb_embedding
    # Inference only: no gradients are needed for the encoder calls.
    with torch.no_grad():
        for k, v in new_kv_pairs:
            new_key = encoder.encode_key(k).unsqueeze(0)
            new_value = encoder.encode_val(v).unsqueeze(0)
            # Each new pair is prepended, so the last pair in new_kv_pairs
            # ends up first in the extended KB.
            kb_embedding_extended = (
                torch.cat([new_key, kb_embedding_extended[0]]),
                torch.cat([new_value, kb_embedding_extended[1]])
            )

    answer = answer_question(model, tokenizer, args.Q, kb_config=kb_config, kb=kb_embedding_extended)
    print(answer)
